from QKDNetwork import QKDNetwork
from he_rps import LP
from he_psp import getPSPHETimeSlot
from load_balance import getLoadBalanceRouteList
import random

# Relay       |   Provision
# SPS(1)      |   Greedy(3)
# RPS-HE(2)   |   PSP-HE(4)

# 2 + 4 = RPSP-HE
# 2 + 3 = RPS-HE + GD
# 1 + 4 = SP + PSP-HE

# Route-plan data structure for QKD requests: [{"sd": [src, dst, demand], "path": [1, 3, 2]}, ...]

class Compare:
    """Evaluate combinations of relay-path selection and key provisioning.

    Strategy numbering used in the method comments:
        relay selection:  SPS (1), RPS-HE (2)
        key provisioning: Greedy (3), PSP-HE (4)

    A "route list" is a list of dicts shaped like
    ``{"sd": [src, dst, demand], "path": [node, node, ...]}``.
    """

    def __init__(self, net: "QKDNetwork"):
        """Keep a reference to ``net`` and sort its requests by demand.

        Sorting rewrites ``net.sd_list`` and ``net.candidate_paths`` on the
        network object that was passed in.
        """
        self.net = net
        self.sortSD()

    def sortSD(self):
        """Reorder requests and their candidate paths by descending demand.

        ``net.sd_list`` and ``net.candidate_paths`` are sorted together so
        that entries at the same index stay paired; both end up as lists.
        The demand is element 2 of each ``sd_list`` entry.
        """
        paired = sorted(
            zip(self.net.sd_list, self.net.candidate_paths),
            key=lambda pair: pair[0][2],
            reverse=True,
        )
        # Build each list on its own: ``a, b = zip(*paired)`` cannot be
        # unpacked when there are no requests at all.
        self.net.sd_list = [sd for sd, _ in paired]
        self.net.candidate_paths = [paths for _, paths in paired]

    def getData(self):
        """Return time-slot counts for [SP + PSP-HE, RPS-HE + GD, RPSP-HE]."""
        # To measure the effect of reducing the number of SPDs, compare
        # sp_gd() with sp_gd_with_spd() instead.
        return [self.sp_psp(), self.rpd_gd(), self.rpsp()]

    def shortest_path_spd(self) -> list:
        """Build a route list from ``net.candidate_paths_spd``.

        Each request is routed over the first path of its ``"paths"`` entry
        (presumably the shortest of the SPD-aware candidates; verify against
        the code that fills ``candidate_paths_spd``).
        """
        return [
            {"sd": entry["sd"], "path": entry["paths"][0]}
            for entry in self.net.candidate_paths_spd
        ]

    def shortest_path(self) -> list:  # SPS(1)
        """Build a route list using the network's shortest-path routine.

        Every request in ``net.sd_list`` is routed with
        ``net.getSPWithTransmitterWithNetworkX(src, dst)``. The result is also
        stored on ``self.candidate_paths``.

        Returns:
            e.g. ``[{"sd": [1, 2, 4], "path": [1, 3, 2]}, ...]``
        """
        routes = [
            {
                "sd": sd,
                "path": self.net.getSPWithTransmitterWithNetworkX(sd[0], sd[1]),
            }
            for sd in self.net.sd_list
        ]
        self.candidate_paths = routes
        return routes

    def cplex_path(self) -> list:  # RPS-HE(2)
        """Return the route list computed by ``getLoadBalanceRouteList``.

        Returns:
            e.g. ``[{"sd": [1, 2, 4], "path": [1, 3, 2]}, ...]``
        """
        # NOTE: an LP(self.net)-based selection used to follow this return
        # statement; it was unreachable and has been dropped.
        return getLoadBalanceRouteList(self.net)

    def _load_routes(self, routeList):
        """Clear all link loads, then add each route's demand along its path."""
        self.net.clearLoad()
        for route in routeList:
            path = route["path"]
            for link_start, link_end in zip(path, path[1:]):
                self.net.G[link_start][link_end]["load"] += route["sd"][2]

    def _greedy_time_slots(self, routeList, link_is_served=None) -> int:
        """Count the time slots greedy provisioning needs for ``routeList``.

        In every slot each node whose ``"transmitter"`` attribute equals 1
        picks, uniformly at random, one adjacent link that still has positive
        load (and that ``link_is_served(node, neighbor)`` accepts, when given)
        and lowers that link's load by the node's ``"transmitter_rate"``.
        Counting stops at the first pass in which no transmitter node sees
        such a link. Load on links that no transmitter node may serve is
        therefore ignored.

        Raises:
            ValueError: a transmitter node with pending load has a
                non-positive ``"transmitter_rate"``; the load could never
                drain and the loop would not terminate.
        """
        self._load_routes(routeList)
        graph = self.net.G
        time_slots = 0
        while True:
            pending = False
            for node_id in graph.nodes:
                node = graph.nodes[node_id]
                if node["transmitter"] != 1:
                    continue
                candidates = [
                    neighbor
                    for neighbor in graph.neighbors(node_id)
                    if graph[node_id][neighbor]["load"] > 0
                    and (link_is_served is None or link_is_served(node_id, neighbor))
                ]
                if not candidates:
                    continue
                pending = True
                rate = node["transmitter_rate"]
                if rate <= 0:
                    raise ValueError(
                        f"node {node_id!r} has pending load but a non-positive "
                        f"transmitter_rate ({rate!r})"
                    )
                chosen = random.choice(candidates)
                # The load is read as G[node][neighbor] but written as
                # G[neighbor][node]; presumably G is undirected so both name
                # the same edge data -- verify if G can be a directed graph.
                graph[chosen][node_id]["load"] -= rate
            if not pending:
                return time_slots
            time_slots += 1

    def greedy_ps_pro_with_spd(self, routeList) -> int:  # Greedy(3)
        """Greedy provisioning restricted to links listed in ``net.spd``.

        A link counts as eligible when ``(a, b)`` or ``(b, a)`` is contained
        in ``net.spd``.

        Args:
            routeList: route list to provision.

        Returns:
            Number of time slots used.
        """
        return self._greedy_time_slots(
            routeList,
            lambda a, b: (a, b) in self.net.spd or (b, a) in self.net.spd,
        )

    def greedy_ps_pro(self, routeList) -> int:  # Greedy(3)
        """Greedy provisioning over every loaded link.

        Args:
            routeList: route list to provision.

        Returns:
            Number of time slots used.
        """
        return self._greedy_time_slots(routeList)

    def cplex_ps_pro(self, routeList) -> int:  # PSP-HE(4)
        """Return the time-slot count from ``getPSPHETimeSlot`` for ``routeList``.

        Raises:
            Exception: whatever ``getPSPHETimeSlot`` raises is reported on
                stdout and then re-raised unchanged.
        """
        try:
            return getPSPHETimeSlot(routeList, self.net)
        except Exception as e:
            print("cplex_ps_pro", e)
            # Propagate the real failure instead of masking it with an
            # UnboundLocalError from a result that was never assigned.
            raise

    def sp_gd_with_spd(self):
        """SPD-aware shortest-path routes + SPD-restricted greedy provisioning."""
        return self.greedy_ps_pro_with_spd(self.shortest_path_spd())

    def sp_gd(self):
        """Shortest-path routes + greedy provisioning (SPS + Greedy)."""
        return self.greedy_ps_pro(self.shortest_path())

    def sp_psp(self):
        """Shortest-path routes + optimized provisioning (SP + PSP-HE)."""
        return self.cplex_ps_pro(self.shortest_path())

    def rpd_gd(self):
        """Optimized routes + greedy provisioning (RPS-HE + GD)."""
        return self.greedy_ps_pro(self.cplex_path())

    def rpsp(self):
        """Optimized routes + optimized provisioning (RPSP-HE)."""
        return self.cplex_ps_pro(self.cplex_path())
    
